local super = require "Graph"

-- XYGraph derives from the class loaded into `super` (Graph). It is assigned
-- as a global rather than a local, and is also the value this file returns.
XYGraph = super:new()

-- Properties that XYGraph:new registers on every instance, with their initial values.
local defaults = {
    regressionEquations = true,
    regressionR2s = true,
}

-- Names of properties that XYGraph:new registers without an initial value (currently none).
local nilDefaults = {}

-- The two axis handles below share their pan behaviour, track thickness and
-- tick-dragging (rescale) logic; they differ only in which axis they act on
-- and which dimension of the content rect they use.

-- Extent of the content rect along x.
local function horizontalBounds(rect)
    return rect:minx(), rect:maxx()
end

-- Extent of the content rect along y.
local function verticalBounds(rect)
    return rect:miny(), rect:maxy()
end

-- Pan-token getter: converts the view point (x, y) into a value on each axis.
local function getPanToken(self, x, y)
    local rect = self:getContentRect()
    local xaxis = self:getHorizontalAxis()
    local yaxis = self:getVerticalAxis()
    local tx = xaxis:scaled(rect, x)
    local ty = yaxis:scaled(rect, y)
    return tx, ty
end

-- Shifts `axis` so that `value` is drawn at view coordinate `position`,
-- keeping the visible extent (minPosition..maxPosition) the same size.
local function panAxis(axis, rect, value, position, minPosition, maxPosition)
    local draggingPosition = axis:scale(rect, value)
    local minValue = axis:scaled(rect, minPosition + draggingPosition - position)
    local maxValue = axis:scaled(rect, maxPosition + draggingPosition - position)
    axis:setRange(minValue, maxValue)
end

-- Pan-token setter: a nil tx / ty leaves the corresponding axis untouched.
local function setPanToken(self, tx, ty, x, y)
    local rect = self:getContentRect()
    if tx then
        panAxis(self:getHorizontalAxis(), rect, tx, x, horizontalBounds(rect))
    end
    if ty then
        panAxis(self:getVerticalAxis(), rect, ty, y, verticalBounds(rect))
    end
end

-- Track thickness hook shared by both handles; always returns (8, true).
local function getTrackThickness(self)
    return 8, true
end

-- Lists the draggable major-tick values of `axis` and their view positions.
-- Ticks are distributed from the point where the axis is drawn: its origin,
-- or the rect edge when the origin lies outside the rect. The tick at that
-- point is left out because rescaling pivots on it, so dragging it can only
-- collapse the range.
local function getScaleTokens(axis, rect, minPosition, maxPosition)
    local originValue = axis:origin()
    local originPosition = axis:scale(rect, originValue)
    local crossingValue = originValue
    if originPosition <= minPosition then
        crossingValue = axis:scaled(rect, minPosition)
    elseif originPosition >= maxPosition then
        crossingValue = axis:scaled(rect, maxPosition)
    end
    local majorValues, majorPositions = axis:distribute(rect, crossingValue)
    local tokens, positions = {}, {}
    for index = 1, #majorValues do
        local value = majorValues[index]
        if value ~= originValue and value ~= crossingValue then
            tokens[#tokens + 1] = value
            positions[#positions + 1] = majorPositions[index]
        end
    end
    return tokens, positions
end

-- Rescales `axis` about the point where it is drawn (its origin position,
-- kept within minPosition..maxPosition via math._mid) so that the tick
-- `token` moves to view coordinate `position`.
local function setScaleToken(axis, rect, token, position, minPosition, maxPosition)
    local oldPosition = axis:scale(rect, token)
    local originPosition = axis:scale(rect, axis:origin())
    local axisPosition = math._mid(minPosition, originPosition, maxPosition)
    if oldPosition == axisPosition then
        -- The token sits on the pivot: the scale factor below would be 0 and
        -- setRange would receive an empty range. Leave the axis unchanged.
        return
    end
    local ratio
    if (oldPosition < axisPosition and position < axisPosition) or (oldPosition > axisPosition and position > axisPosition) then
        ratio = (oldPosition - axisPosition) / (position - axisPosition)
    else
        -- Dragged onto or across the pivot: rely on the cap below.
        ratio = math.huge
    end
    -- The new range spans `ratio` times the current extent about the pivot;
    -- cap it at 20 * (token's distance from the pivot / rect extent).
    ratio = math.min(ratio, 20 * math.abs(oldPosition - axisPosition) / (maxPosition - minPosition))
    local maxValue = axis:scaled(rect, axisPosition + ratio * (maxPosition - axisPosition))
    local minValue = axis:scaled(rect, axisPosition - ratio * (axisPosition - minPosition))
    axis:setRange(minValue, maxValue)
end

local handles = {
    -- Horizontal-axis handle: its track is a horizontal segment across the
    -- content rect, placed at the vertical axis's origin (kept inside the rect).
    AxisHandle:new{
        actionName = "Adjust Axes",
        token = Hook:new(getPanToken, setPanToken),
        track = Hook:new(
            function(self)
                local rect = self:getContentRect()
                local axis = self:getVerticalAxis()
                local originPosition = axis:scale(rect, axis:origin())
                local axisPosition = math._mid(rect.bottom, originPosition, rect.top)
                return rect.left, axisPosition, rect.right, axisPosition
            end,
            nil),
        trackThickness = Hook:new(getTrackThickness, nil),
        tokenPositions = Hook:new(
            function(self)
                local rect = self:getContentRect()
                return getScaleTokens(self:getHorizontalAxis(), rect, horizontalBounds(rect))
            end,
            function(self, token, position)
                local rect = self:getContentRect()
                setScaleToken(self:getHorizontalAxis(), rect, token, position, horizontalBounds(rect))
            end),
    },
    -- Vertical-axis handle: its track is a vertical segment across the content
    -- rect, placed at the horizontal axis's origin (kept inside the rect).
    AxisHandle:new{
        actionName = "Adjust Axes",
        token = Hook:new(getPanToken, setPanToken),
        track = Hook:new(
            function(self)
                local rect = self:getContentRect()
                local axis = self:getHorizontalAxis()
                local originPosition = axis:scale(rect, axis:origin())
                local axisPosition = math._mid(rect.left, originPosition, rect.right)
                return axisPosition, rect.bottom, axisPosition, rect.top
            end,
            nil),
        trackThickness = Hook:new(getTrackThickness, nil),
        tokenPositions = Hook:new(
            function(self)
                local rect = self:getContentRect()
                return getScaleTokens(self:getVerticalAxis(), rect, verticalBounds(rect))
            end,
            function(self, token, position)
                local rect = self:getContentRect()
                setScaleToken(self:getVerticalAxis(), rect, token, position, verticalBounds(rect))
            end),
    },
}

-- Creates a graph instance via the superclass constructor, registers the
-- XY-specific properties and installs an observer on the layer list.
function XYGraph:new()
    self = super.new(self)

    -- Properties with an initial value first, then those registered without one.
    for name, initialValue in pairs(defaults) do
        self:addProperty(name, initialValue)
    end
    for _, name in pairs(nilDefaults) do
        self:addProperty(name)
    end

    local layerList = self:getLayerList()

    -- Returns the dataset wrapped by a layer's data source, or nothing if the
    -- layer has no data source.
    local function layerDataset(layer)
        local source = layer:getDataset()
        if source then
            return source:getDataset()
        end
    end

    -- When a layer without a data source is added, give it a fresh DataSource
    -- wrapping whatever dataset layerList:vote returns for the layers' datasets.
    -- The observer is stored on the instance as _xyAddLayerObserver.
    self._xyAddLayerObserver = function(addedLayer)
        if addedLayer:getDataset() then
            return
        end
        local chosenDataset = layerList:vote(layerDataset)
        local source = DataSource:new(chosenDataset)
        unarchived(source)
        addedLayer:setDataset(source)
    end
    layerList:addEventObserver('add', self._xyAddLayerObserver)

    return self
end

-- Builds a fresh table (via appendtables) from the XY axis handles and the
-- handles returned by the superclass.
function XYGraph:getHandles()
    local combined = {}
    return appendtables(combined, handles, super.getHandles(self))
end

-- Extends the inherited font inspector list with inspectors for the quantity
-- font ('Axis Values') and the label font ('Text Series').
function XYGraph:getFontInspectors()
    local inspectors = super.getFontInspectors(self)
    local extraFonts = {
        { TypographyScheme.quantityFont, 'Axis Values' },
        { TypographyScheme.labelFont, 'Text Series' },
    }
    for _, entry in ipairs(extraFonts) do
        inspectors:add(self:createFontInspector(entry[1], entry[2]))
    end
    return inspectors
end

-- Sample scatter data for the preview thumbnail; x/y are fractions passed to
-- rect:mapPoint. This is an ordered list rather than a keyed table walked with
-- pairs (whose traversal order is unspecified), so the series are always drawn
-- in the same order. It is built once instead of on every draw.
local THUMBNAIL_SERIES = {
    {
        name = 'point',
        points = {
            { x = 0.33, y = 0.50 },
        },
    },
    {
        name = 'series1',
        points = {
            { x = 0.09, y = 0.14 },
            { x = 0.13, y = 0.18 },
            { x = 0.17, y = 0.30 },
            { x = 0.21, y = 0.31 },
            { x = 0.25, y = 0.42 },
            { x = 0.29, y = 0.52 },
            { x = 0.33, y = 0.50 },
        },
    },
    {
        name = 'series2',
        points = {
            { x = 0.30, y = 0.20 },
            { x = 0.35, y = 0.24 },
            { x = 0.40, y = 0.34 },
            { x = 0.45, y = 0.32 },
            { x = 0.50, y = 0.47 },
            { x = 0.55, y = 0.60 },
            { x = 0.60, y = 0.78 },
        },
    },
    {
        name = 'series3',
        points = {
            { x = 0.48, y = 0.12 },
            { x = 0.54, y = 0.22 },
            { x = 0.60, y = 0.29 },
            { x = 0.66, y = 0.26 },
            { x = 0.72, y = 0.28 },
            { x = 0.78, y = 0.38 },
            { x = 0.84, y = 0.36 },
        },
    },
}

-- Draws a 4x4 oval marker for each sample point of every series that has a
-- paint in `paints`; markers of the 'point' series are followed by a 'Label' text.
local function drawThumbnailSeries(canvas, rect, fonts, paints, padX)
    for _, series in ipairs(THUMBNAIL_SERIES) do
        local seriesPaint = paints[series.name]
        if seriesPaint then
            for _, fraction in ipairs(series.points) do
                local pointX, pointY = rect:mapPoint(fraction.x, fraction.y)
                -- Set the series paint per marker: drawing a label switches the
                -- canvas to the label paint, which must not leak into later markers.
                canvas:setPaint(seriesPaint)
                    :fill(Path.oval{ left = pointX - 2, bottom = pointY - 2, right = pointX + 2, top = pointY + 2 })
                if series.name == 'point' and fonts.label and paints.label then
                    canvas:setPaint(paints.label)
                        :setFont(fonts.label)
                        :drawText('Label', pointX + 2 + padX, pointY + (fonts.label:descent() - fonts.label:ascent()) / 2, 0)
                end
            end
        end
    end
end

--- Renders a miniature XY chart (title, axis title, axis value labels, sample
-- points and axes) used by the scheme previews.
-- @param canvas drawing surface; setPaint/setFont/drawText calls are chained on it
-- @param rect   area to draw the chart into
-- @param fonts  table with `axisValue` (required) and optional `title`, `axisTitle`, `label`
-- @param paints table with `fill`, `label`, `axis` (required) and optional `background`,
--               `title`, `axisTitle`, plus series paints keyed `point`, `series1`..`series3`
local function drawThumbnail(canvas, rect, fonts, paints)
    if paints.background then
        -- The background covers the whole canvas (its metrics rect), not just `rect`.
        canvas:setPaint(paints.background)
            :fill(Path.rect(canvas:metrics():rect(), 3))
    end
    local PADX, PADY = 4, 2
    -- '100' is the widest y-axis value label; its size drives the margins.
    local hundred = StyledString.new('100', { font = fonts.axisValue })
    local valueRect = hundred:measure()
    rect = rect:inset{
        left = valueRect:width() + PADX,
        bottom = 0,
        right = valueRect:width() * 2 / 3,
        top = 0,
    }
    if fonts.title and paints.title then
        canvas:setPaint(paints.title)
            :setFont(fonts.title)
            :drawText('Title', rect:midx(), rect:maxy() - fonts.title:ascent(), 0.5)
        rect = rect:inset{ left = 0, bottom = 0, right = 0, top = fonts.title:ascent() + fonts.title:descent() + PADY }
    end
    if fonts.axisTitle and paints.axisTitle then
        canvas:setPaint(paints.axisTitle)
            :setFont(fonts.axisTitle)
            :drawText('Axis Title', rect:midx(), rect:miny() + fonts.axisTitle:descent(), 0.5)
        rect = rect:inset{ left = 0, bottom = fonts.axisTitle:ascent() + fonts.axisTitle:descent() + PADY, right = 0, top = 0 }
    end
    -- Leave room below for the x-axis value labels and above for half of the top y label.
    rect = rect:inset{ left = 0, bottom = valueRect:height() + PADY, right = 0, top = valueRect:height() / 2 }
    canvas:setPaint(paints.fill)
        :fill(Path.rect(rect))
    local valueHeight = valueRect:height()
    local minx, miny = rect:minx(), rect:miny()
    local width, height = rect:width(), rect:height()
    local labelX, labelY = minx - PADX, miny - PADY - valueRect:maxy()
    canvas:setPaint(paints.label)
        :setFont(fonts.axisValue)
        :drawText('0', labelX, miny - valueHeight / 2 - valueRect:miny(), 1)
        :drawText('50', labelX, miny - valueHeight / 2 - valueRect:miny() + height * 1 / 2, 1)
        :drawText(hundred, labelX, miny - valueHeight / 2 - valueRect:miny() + height, 1)
        :drawText('0', minx, labelY, 0.5)
        :drawText('1', minx + width * 1 / 3, labelY, 0.5)
        :drawText('2', minx + width * 2 / 3, labelY, 0.5)
        :drawText('3', minx + width, labelY, 0.5)
    drawThumbnailSeries(canvas, rect, fonts, paints, PADX)
    -- Axes: a vertical line down the left edge joined to a horizontal line along the bottom.
    canvas:setPaint(paints.axis)
        :stroke(Path.line{ x1 = minx, x2 = minx, y1 = miny + height, y2 = miny }:addLine{ x = minx + width, y = miny })
end

-- Draws the XY thumbnail with every text role taken from `typographyScheme`
-- at size 12, using fixed Color.gray paints and an invisible plot fill.
function XYGraph:drawTypographySchemePreview(canvas, rect, typographyScheme)
    local FONT_SIZE = 12
    local function fontFor(role)
        return typographyScheme:getFont(role, FONT_SIZE)
    end

    local fonts = {
        title = fontFor(TypographyScheme.titleFont),
        axisTitle = fontFor(TypographyScheme.subtitleFont),
        axisValue = fontFor(TypographyScheme.quantityFont),
        label = fontFor(TypographyScheme.labelFont),
    }
    local paints = {
        fill = Color.invisible,
        title = Color.gray(0, 1),
        axisTitle = Color.gray(0, 1),
        label = Color.gray(0, 1),
        axis = Color.gray(0, 0.4),
        point = Color.gray(0, 0.4),
    }
    drawThumbnail(canvas, rect, fonts, paints)
end

-- Draws the XY thumbnail with paints taken from `colorScheme` (including three
-- data-series paints) and fonts from this graph's own typography scheme at size 12.
function XYGraph:drawColorSchemePreview(canvas, rect, colorScheme)
    local FONT_SIZE = 12
    local typography = self:getTypographyScheme()
    local function fontFor(role)
        return typography:getFont(role, FONT_SIZE)
    end
    local function paintFor(role)
        return colorScheme:getPaint(role)
    end

    local fonts = {
        title = fontFor(TypographyScheme.titleFont),
        axisValue = fontFor(TypographyScheme.quantityFont),
        label = fontFor(TypographyScheme.labelFont),
    }
    local paints = {
        background = paintFor(ColorScheme.pageBackgroundPaint),
        fill = paintFor(ColorScheme.backgroundPaint),
        title = paintFor(ColorScheme.titlePaint),
        label = paintFor(ColorScheme.labelPaint),
        axis = paintFor(ColorScheme.strokePaint),
        series1 = colorScheme:getDataSeriesPaint(1, 3),
        series2 = colorScheme:getDataSeriesPaint(2, 3),
        series3 = colorScheme:getDataSeriesPaint(3, 3),
    }
    drawThumbnail(canvas, rect, fonts, paints)
end

-- Description string for the horizontal axis of an XY graph.
function XYGraph.getHorizontalAxisDescription(_)
    return "X"
end

-- Description string for the vertical axis of an XY graph.
function XYGraph.getVerticalAxisDescription(_)
    return "Y"
end

-- Current value of the 'regressionEquations' property.
function XYGraph.showRegressionEquations(graph)
    return graph:getProperty('regressionEquations')
end

-- Property hook for 'regressionEquations'.
function XYGraph.getShowRegressionEquationsHook(graph)
    return graph:getPropertyHook('regressionEquations')
end

-- Current value of the 'regressionR2s' property.
function XYGraph.showRegressionR2s(graph)
    return graph:getProperty('regressionR2s')
end

-- Property hook for 'regressionR2s'.
function XYGraph.getShowRegressionR2sHook(graph)
    return graph:getPropertyHook('regressionR2s')
end

-- Module result: the XYGraph class (which is also left in the global XYGraph).
return XYGraph
